"""Generate Gaussian distribution of point loads."""
import logging
# Configure root logging at DEBUG level as soon as this module is loaded.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger; used in this file to report malformed node lines.
logger = logging.getLogger(__name__)

def main():
    """Command-line entry point: parse options, name the load file, write it.

    The output file name encodes the sigma, center, amplitude, amplitude
    cutoff and symmetry options.

    Returns:
      0

    """
    args = read_cli()

    # NOTE: args.search_tol is parsed by read_cli() but is not used here.
    name_values = (args.sigma[0], args.sigma[1], args.sigma[2],
                   args.center[0], args.center[1], args.center[2],
                   args.amp, args.amp_cut, args.sym)
    name_template = ("gauss_exc_sigma_%.3f_%.3f_%.3f_"
                     "center_%.3f_%.3f_%.3f_amp_%.3f_amp_cut_%.3f_%s.dyn")
    loadfilename = name_template % name_values

    generate_loads(sigma=args.sigma,
                   center=args.center,
                   amp=args.amp,
                   amp_cut=args.amp_cut,
                   sym=args.sym,
                   direction=args.direction,
                   loadfilename=loadfilename,
                   nodefile=args.nodefile)

    return 0


def generate_loads(sigma, center, amp=1.0, amp_cut=0.05, sym="qsym",
                   direction=-3, loadfilename="loads.dyn",
                   nodefile="nodes.dyn", tukey_length=0.0, tukey_alpha=0.25):
    """Compute per-node point-load amplitudes and write them to a load file.

    Nodes are read from ``nodefile``, weighted, and the surviving
    (node ID, amplitude) pairs are written to ``loadfilename``.

    Args:
      sigma: Gaussian widths, forwarded to the amplitude calculation
      center: excitation center coordinates
      amp: peak amplitude (Default value = 1.0)
      amp_cut: lower relative-amplitude threshold (Default value = 0.05)
      sym: mesh symmetry (Default value = "qsym")
      direction: signed load direction (Default value = -3)
      loadfilename: output load file (Default value = "loads.dyn")
      nodefile: input node file (Default value = "nodes.dyn")
      tukey_length: length of Tukey window; 0.0 selects the purely
        Gaussian weighting (Default value = 0.0)
      tukey_alpha: Tukey window rolloff fraction (Default value = 0.25)

    Returns:
      0

    """
    node_loads = read_process_nodes(sigma=sigma,
                                    center=center,
                                    sym=sym,
                                    amp=amp,
                                    amp_cut=amp_cut,
                                    nodefile=nodefile,
                                    tukey_length=tukey_length,
                                    tukey_alpha=tukey_alpha)

    banner = "$Generated by GaussExc.py\n"
    write_load_file(loadfilename, node_loads, direction,
                    header_comment=banner)

    return 0


def read_process_nodes(sigma, center, sym="qsym", amp=1.0, amp_cut=0.05,
                       nodefile="nodes.dyn", tukey_length=0.0,
                       tukey_alpha=0.25):
    """Read the node file and collect the nodes that receive a point load.

    Each line is parsed with read_node_positions(); lines that yield no
    fields are skipped.  The load amplitude is computed with
    calc_gauss_amp() when ``tukey_length`` is 0.0 and with calc_tukey_amp()
    otherwise.  Nodes whose amplitude is falsy (None or 0.0) are dropped.

    Args:
      sigma: Gaussian widths
      center: excitation center coordinates
      sym: mesh symmetry (default = 'qsym')
      amp: amplitude (default = 1.0)
      amp_cut: lower amplitude threshold (default = 0.05)
      nodefile: default = 'nodes.dyn'
      tukey_length: length of Tukey window (0.0 defaults to Gaussian)
      tukey_alpha: percentage of Tukey window to rolloff
        (Default value = 0.25)

    Returns:
      list of (int node ID, load amplitude) tuples

    """
    use_gaussian = tukey_length == 0.0
    selected = []

    with open(nodefile, 'r') as node_lines:
        for raw_line in node_lines:
            node = read_node_positions(raw_line)
            if not node:
                # comment / keyword line: nothing to evaluate
                continue

            if use_gaussian:
                load_amp = calc_gauss_amp(node, center, sigma, amp,
                                          amp_cut, sym)
            else:
                load_amp = calc_tukey_amp(node, center, sigma,
                                          tukey_length, tukey_alpha,
                                          amp, amp_cut, sym)

            if not load_amp:
                continue

            selected.append((int(node[0]), load_amp))

    return selected


def write_load_file(loadfilename, load_nodeID_amp, direction=-3,
                    header_comment="$Generated by GaussExc.py\n"):
    """Write the point loads as a *LOAD_NODE_POINT keyword block.

    One line ``<node ID>,<|direction|>,1,<signed amplitude>`` is written per
    entry, with the amplitude formatted to 4 decimal places and multiplied
    by the sign of ``direction``.

    Args:
      loadfilename: path of the load file to (over)write
      load_nodeID_amp: iterable of (int node ID, float amp) pairs
      direction: default = -3 (orientation (1, 2, 3) and sign)
      header_comment: first line(s) written to the file
        (Default value = "$Generated by GaussExc.py\n")

    Returns:
      0

    """
    # Plain Python is enough for a scalar magnitude / sign.
    dof = abs(direction)
    sign = (direction > 0) - (direction < 0)

    with open(loadfilename, 'w') as lfile:
        lfile.write(header_comment)
        lfile.write("*LOAD_NODE_POINT\n")
        # The constant third field is presumably the load curve ID --
        # TODO confirm against the LS-DYNA keyword manual.
        lfile.writelines(
            f"{node_id},{dof},1,{sign * node_amp:.4f}\n"
            for node_id, node_amp in load_nodeID_amp
        )
        lfile.write("*END\n")

    return 0


def read_node_positions(line):
    """read node position fields from line in nodefile

    Ignore lines that start with '$' (comments) and '*' keywords, as well
    as empty / whitespace-only lines.  Data entries are expected to be
    comma-delimited; every entry is converted with float().

    Args:
      line: single line string from nodefile

    Returns:
      fields: 1x4 float list of node ID, x, y, z; None for comment,
        keyword and blank lines

    Raises:
      ValueError: if an entry cannot be converted to float
      SyntaxError: (from check_num_fields) if the line does not contain
        exactly 4 entries

    """
    # Blank lines (or an empty string) carry no node data; skipping them
    # avoids an IndexError on line[0] and a ValueError from float('').
    if not line.strip():
        return None

    if line.startswith(("$", "*")):
        return None

    fields = [float(entry) for entry in line.rstrip('\n').split(',')]
    check_num_fields(fields)

    return fields


def check_num_fields(fields):
    """check for 4 fields

    Args:
      fields: list (node ID, x, y, z)

    Returns:
      0 when exactly 4 fields are present

    Raises:
      SyntaxError: if the number of fields is not 4 (the error is also
        logged through the module logger)

    """
    if len(fields) != 4:
        logger.error("Unexpected number of node columns.")
        # The raise ends the function; no explicit exit is needed after it.
        raise SyntaxError("Unexpected number of node columns.")

    return 0


def sym_scale_amp(fields, nodeGaussAmp, sym, search_tol=0.0001):
    """Reduce the point load amplitude for nodes lying on symmetry planes.

    For 'qsym' the amplitude is quartered when both x and y are within
    ``search_tol`` of zero and halved when only one of them is.  For 'hsym'
    it is halved when x is within ``search_tol`` of zero.  For 'none' it is
    returned unchanged.  Any other value terminates via sys.exit().

    Args:
        fields (list): node ID, x, y, z
        nodeGaussAmp (float): amplitude of point load
        sym (str): type of mesh symmetry (none, qsym, hsym)
        search_tol (float): spatial tolerance to find nearby nodes

    Returns:
        nodeGaussAmp: symmetry-scaled point load amplitude

    """
    from math import fabs
    import sys

    if sym == 'none':
        return nodeGaussAmp

    if sym == 'hsym':
        if fabs(fields[1]) < search_tol:
            return nodeGaussAmp / 2
        return nodeGaussAmp

    if sym != 'qsym':
        sys.exit('ERROR: Invalid symmetry option specified.')

    # quarter symmetry: check both in-plane coordinates
    x_on_plane = fabs(fields[1]) < search_tol
    y_on_plane = fabs(fields[2]) < search_tol

    if x_on_plane and y_on_plane:
        return nodeGaussAmp / 4
    if x_on_plane or y_on_plane:
        return nodeGaussAmp / 2
    return nodeGaussAmp


def calc_gauss_amp(node_xyz, center=(0.0, 0.0, -2.0), sigma=(1.0, 1.0, 1.0),
                   amp=1.0, amp_cut=0.05, sym="qsym"):
    """calculate the Gaussian amplitude at the node

    The relative weight is exp(-sum(((x_i - center_i) / sigma_i)**2)) over
    the three coordinates (note: no factor of 1/2 in the exponent).

    Args:
      node_xyz: list of node ID, x, y, z (coordinates are read from
        indices 1-3)
      center: x,y,z of Gaussian center (Default value = (0.0, 0.0, -2.0))
      sigma: x,y,z Gaussian width (Default value = (1.0, 1.0, 1.0))
      amp: peak Gaussian source amplitude
      amp_cut: lower threshold (fraction of peak) for amplitude creating a
        point load
      sym: mesh symmetry (qsym, hsym, none)

    Returns:
      nodeGaussAmp - point load amplitude at the specified node, or None
      if the relative weight is below amp_cut

    """
    from math import exp, pow

    exp1 = pow((node_xyz[1] - center[0]) / sigma[0], 2)
    exp2 = pow((node_xyz[2] - center[1]) / sigma[1], 2)
    exp3 = pow((node_xyz[3] - center[2]) / sigma[2], 2)

    # Compare the relative weight directly with the cutoff rather than
    # dividing the scaled amplitude by amp, which fails for amp == 0.
    rel_weight = exp(-(exp1 + exp2 + exp3))
    if rel_weight < amp_cut:
        return None

    return sym_scale_amp(node_xyz, amp * rel_weight, sym)


def calc_tukey_amp(node_xyz, center=(0.0, 0.0, -2.0), sigma=(1.0, 1.0),
                   tukey_length=1.0, tukey_alpha=0.25, amp=1.0,
                   amp_cut=0.05, sym="qsym"):
    """calculate the Gaussian (x, y) / Tukey (z) amplitude at the node

    The relative weight is exp(-(((x - cx) / sx)**2 + ((y - cy) / sy)**2))
    multiplied by the Tukey window scale returned by tukey_z_scale() for
    the node's z coordinate.

    Args:
      node_xyz: list of node ID, x, y, z (coordinates are read from
        indices 1-3)
      center: x,y,z of excitation center (Default value = (0.0, 0.0, -2.0))
      sigma: x,y Gaussian width (Default value = (1.0, 1.0))
      tukey_length: length of axial extent, centered at center
      tukey_alpha: percentage of rolloff (see scipy documentation)
      amp: peak source amplitude
      amp_cut: lower threshold (fraction of peak) for amplitude creating a
        point load
      sym: mesh symmetry (qsym, hsym, none)

    Returns:
      nodeGaussAmp - point load amplitude at the specified node, or None
      if the relative weight is below amp_cut

    """
    from math import exp, pow

    exp1 = pow((node_xyz[1] - center[0]) / sigma[0], 2)
    exp2 = pow((node_xyz[2] - center[1]) / sigma[1], 2)

    z_scale = tukey_z_scale(node_xyz[3], center[2], tukey_length, tukey_alpha)

    lateral_weight = exp(-(exp1 + exp2))

    # Compare the relative weight directly with the cutoff rather than
    # dividing the scaled amplitude by amp, which fails for amp == 0.
    if lateral_weight * z_scale < amp_cut:
        return None

    nodeGaussAmp = amp * lateral_weight * z_scale

    return sym_scale_amp(node_xyz, nodeGaussAmp, sym)


def tukey_z_scale(z, center, length, alpha=0.25, points=101):
    """Look up the Tukey window scale for a z-coordinate.

    The window spans |center| - length / 2 to |center| + length / 2 and is
    sampled at ``points`` locations.  The absolute value of ``z`` is used;
    the scale is the window sample at the first location that is >= |z|.

    Args:
      z: z-coordinate
      center: center of Tukey window
      length: length of Tukey window
      alpha: rolloff (percentage of window) (Default value = 0.25)
      points: number of points in Tukey window (Default value = 101)

    Returns:
      z_scale (scale, relative to 1.0); 0.0 outside the window

    """
    import numpy as np
    # The scipy.signal.tukey alias is deprecated and no longer available in
    # newer SciPy releases; the window functions live in scipy.signal.windows.
    from scipy.signal.windows import tukey

    z = np.abs(z)
    zmin = np.abs(center) - length / 2
    zmax = np.abs(center) + length / 2

    # Outside the window there is no need to build the window at all.
    if z < zmin or z > zmax:
        return 0.0

    z_tukey_win = np.linspace(zmin, zmax, points)
    z_tukey_amp = tukey(points, alpha)

    return z_tukey_amp[np.min(np.where(z_tukey_win >= z))]


def read_cli():
    """read CLI arguments

    Returns:
      argparse namespace with nodefile, sigma, amp, direction, amp_cut,
      center, search_tol and sym attributes

    """
    import argparse as ap

    p = ap.ArgumentParser(description="Generate *LOAD_NODE_POINT data "
                          "with Gaussian weighting about dim1 = 0, "
                          "dim2 = 0, extending through dim3.  All "
                          "spatial units are in the unit system for the "
                          "node definitions.",
                          formatter_class=ap.ArgumentDefaultsHelpFormatter)
    p.add_argument("--nodefile",
                   help="Node definition file (*.dyn)",
                   default="nodes.dyn")
    p.add_argument("--sigma",
                   type=float,
                   help="Standard deviations in 3 dims",
                   nargs=3,
                   default=(1.0, 1.0, 1.0))
    p.add_argument("--amp",
                   type=float,
                   help="Peak Gaussian amplitude",
                   default=1.0)
    p.add_argument("--direction",
                   type=int,
                   help="direction of load",
                   default=-3)
    p.add_argument("--amp_cut",
                   type=float,
                   help="Cutoff from peak amplitude to discard (so a lot "
                   "of the nodes don't have negligible loads on them)",
                   default=0.05)
    p.add_argument("--center",
                   type=float,
                   help="Gaussian center",
                   nargs=3,
                   default=(0.0, 0.0, -2.0))
    p.add_argument("--search_tol",
                   type=float,
                   help="Node search tolerance",
                   default=0.0001)
    # Reject unsupported symmetry names at parse time; these are the only
    # values sym_scale_amp() accepts.
    p.add_argument("--sym",
                   help="Mesh symmetry (qsym, hsym or none)",
                   choices=("qsym", "hsym", "none"),
                   default="qsym")

    opts = p.parse_args()

    return opts

# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
